{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# Install the warning filter before the session is created so that nothing\n",
    "# emitted during session start-up clutters the notebook output.\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Single TF1-style session shared by every cell below.\n",
    "sess = tf.Session()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): in a top-to-bottom run no variable has been created yet at\n",
    "# this point, so this call initializes nothing; variables created by later\n",
    "# cells need an initializer run after they are built.\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# dataset 读取csv "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_example(line):\n",
    "    \"\"\"Decode CSV text into a feature dict and a binary label tensor.\n",
    "\n",
    "    Args:\n",
    "        line: string tensor holding the CSV record(s) to decode.\n",
    "\n",
    "    Returns:\n",
    "        A ``(features, target)`` tuple: ``features`` maps each name in\n",
    "        ``FEATURE_NAME`` (minus ``TARGET``) to its decoded column, and\n",
    "        ``target`` is an int32 tensor reshaped to ``[-1, 1]`` that is 1 where\n",
    "        the ``TARGET`` column equals ``'>50K'`` and 0 elsewhere.\n",
    "    \"\"\"\n",
    "    # CSV_RECORD_DEFAULTS, FEATURE_NAME and TARGET are not defined in this\n",
    "    # notebook; they must already exist in the kernel namespace.\n",
    "    decoded = tf.io.decode_csv(line, record_defaults=CSV_RECORD_DEFAULTS)\n",
    "    features = {name: column for name, column in zip(FEATURE_NAME, decoded)}\n",
    "\n",
    "    is_positive = tf.equal(features.pop(TARGET), '>50K')\n",
    "    target = tf.reshape(tf.cast(is_positive, tf.int32), [-1, 1])\n",
    "    return features, target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DATA_DIR is expected to be a format string defined outside this notebook.\n",
    "input_path = DATA_DIR.format('train')\n",
    "\n",
    "dataset = (\n",
    "    tf.data.TextLineDataset(input_path)\n",
    "    .skip(1)   # drop the first line (presumably the CSV header)\n",
    "    .batch(5)  # batch before parsing so parse_example decodes 5 lines per call\n",
    "    .map(parse_example, num_parallel_calls=8)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "({'age': array([50, 38, 53, 28, 37], dtype=int32),\n",
       "  'workclass': array([b'Self-emp-not-inc', b'Private', b'Private', b'Private',\n",
       "         b'Private'], dtype=object),\n",
       "  'fnlwgt': array([ 83311, 215646, 234721, 338409, 284582], dtype=int32),\n",
       "  'education': array([b'Bachelors', b'HS-grad', b'11th', b'Bachelors', b'Masters'],\n",
       "        dtype=object),\n",
       "  'education_num': array([13,  9,  7, 13, 14], dtype=int32),\n",
       "  'marital_status': array([b'Married-civ-spouse', b'Divorced', b'Married-civ-spouse',\n",
       "         b'Married-civ-spouse', b'Married-civ-spouse'], dtype=object),\n",
       "  'occupation': array([b'Exec-managerial', b'Handlers-cleaners', b'Handlers-cleaners',\n",
       "         b'Prof-specialty', b'Exec-managerial'], dtype=object),\n",
       "  'relationship': array([b'Husband', b'Not-in-family', b'Husband', b'Wife', b'Wife'],\n",
       "        dtype=object),\n",
       "  'race': array([b'White', b'White', b'Black', b'Black', b'White'], dtype=object),\n",
       "  'gender': array([b'Male', b'Male', b'Male', b'Female', b'Female'], dtype=object),\n",
       "  'capital_gain': array([0, 0, 0, 0, 0], dtype=int32),\n",
       "  'capital_loss': array([0, 0, 0, 0, 0], dtype=int32),\n",
       "  'hours_per_week': array([13, 40, 40, 40, 40], dtype=int32),\n",
       "  'native_country': array([b'United-States', b'United-States', b'United-States', b'Cuba',\n",
       "         b'United-States'], dtype=object)},\n",
       " array([[0],\n",
       "        [0],\n",
       "        [0],\n",
       "        [0],\n",
       "        [0]], dtype=int32))"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pull one batch out of the pipeline as concrete numpy values; the cells\n",
    "# below feed these arrays to the feature-column examples.\n",
    "iterator = dataset.make_one_shot_iterator()\n",
    "next_batch = iterator.get_next()\n",
    "features, label = sess.run(next_batch)\n",
    "features, label"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# feature_column示例"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- feature_column的输入支持：列名，feature_column\n",
    "- feature_column的输出：dense, sparse"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 连续型输入"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. numeric_column: 输入是原始连续特征，输出是numeric_column"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[50.],\n",
       "       [38.],\n",
       "       [53.],\n",
       "       [28.],\n",
       "       [37.]], dtype=float32)"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wrap the raw 'age' feature as a numeric column and materialize it.\n",
    "age = tf.feature_column.numeric_column(key='age')\n",
    "age_dense = tf.feature_column.input_layer(features, [age])\n",
    "sess.run(age_dense)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. bucketized_column: 输入是numeric_column, 输出是one-hot格式"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 注意bucketized不能对原始输入做操作必须先转成numeric_column\n",
    "- boundaries=[0., 1., 2.] 分桶是 (-inf, 0.), [0., 1.), [1., 2.), and [2., +inf)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n",
       "       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]], dtype=float32)"
      ]
     },
     "execution_count": 112,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ten boundaries yield eleven buckets, hence the 11-wide one-hot rows.\n",
    "age_boundaries = [18, 25, 30, 35, 40, 45, 50, 55, 60, 65]\n",
    "age_bin = tf.feature_column.bucketized_column(source_column=age, boundaries=age_boundaries)\n",
    "age_bin_onehot = tf.feature_column.input_layer(features, [age_bin])\n",
    "sess.run(age_bin_onehot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 离散型输入：可以是数值/字符串型离散值"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. categorical_column_with_vocabulary_list/file: 输入是数量较少的离散值(str/int)，输出categorical_column"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "VocabularyListCategoricalColumn(key='gender', vocabulary_list=('Male', 'Female'), dtype=tf.string, default_value=-1, num_oov_buckets=0)"
      ]
     },
     "execution_count": 120,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Map the two listed string values of 'gender' to integer ids.\n",
    "cat_gender = tf.feature_column.categorical_column_with_vocabulary_list(\n",
    "    key='gender',\n",
    "    vocabulary_list=['Male', 'Female'],\n",
    ")\n",
    "cat_gender"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. categorical_column_with_hash_bucket: 输入是数量较多的离散值(str/int), 输出categorical_column"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 往往用于vocabulary_list未知或者太多不好枚举，这时hash_bucket_size < 词表大小，部分取值会有相同的hash值\n",
    "- 但有时懒得去写上面的vocabulary_list，这时只要hash_bucket_size远大于词表大小，hash冲突的概率很低，效果和上面近似（但不能完全排除冲突）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "HashedCategoricalColumn(key='workclass', hash_bucket_size=100, dtype=tf.string)"
      ]
     },
     "execution_count": 115,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hash the 'workclass' strings into 100 buckets instead of listing a vocabulary.\n",
    "cat_workclass = tf.feature_column.categorical_column_with_hash_bucket(\n",
    "    key='workclass',\n",
    "    hash_bucket_size=100,\n",
    ")\n",
    "cat_workclass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. categorical_column_with_identity：输入是int类型的离散值, 输出categorical_column"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 离散值的取值范围在[0, num_buckets)，每个int会被当作一个categorical\n",
    "- 取值范围外的会用default_value，未指定default_value会报错， default_value必须在上述范围内。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "IdentityCategoricalColumn(key='education_num', number_buckets=20, default_value=0)"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use the integer 'education_num' value itself as the category id;\n",
    "# out-of-range inputs fall back to bucket 0.\n",
    "cat_education_num = tf.feature_column.categorical_column_with_identity(\n",
    "    key='education_num',\n",
    "    num_buckets=20,\n",
    "    default_value=0,\n",
    ")\n",
    "cat_education_num"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 特征交叉/映射"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. crossed_column：输入是原始离散特征/bucketized_column/categorical_column（不支持hash_bucket得到的categorical_column）, 输出sparse vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross raw categorical features referenced by key.\n",
    "edu_occu = tf.feature_column.crossed_column(['education', 'occupation'], hash_bucket_size=10)\n",
    "\n",
    "# Cross existing categorical columns (gender x education_num).\n",
    "gender_education_num = tf.feature_column.crossed_column(\n",
    "    [cat_gender, cat_education_num], hash_bucket_size=10)\n",
    "# Backward-compatible alias: the original name suggested 'workclass', but the\n",
    "# cross is built from cat_gender and cat_education_num.\n",
    "workclass_gender = gender_education_num"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. embedding_column: 输入是categorical_column, 输出是dense vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Represent each 'gender' category as a 3-dimensional dense embedding.\n",
    "emb_gender = tf.feature_column.embedding_column(categorical_column=cat_gender, dimension=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'sess' is not defined",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-9348e7a5f25d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtables_initializer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mddddglobal_variables_initializer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_column\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minput_layer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0memb_gender\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'sess' is not defined"
     ],
     "output_type": "error"
    }
   ],
   "source": [
    "# Build the layer first: input_layer is what creates the embedding variable\n",
    "# and the vocabulary lookup table, so the initializers must run afterwards.\n",
    "emb_gender_dense = tf.feature_column.input_layer(features, emb_gender)\n",
    "sess.run(tf.tables_initializer())\n",
    "sess.run(tf.global_variables_initializer())\n",
    "sess.run(emb_gender_dense)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. indicator_column: 输入是categorical_column/crossed_column, 输出是one-hot vector "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'cat_gender' is not defined",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-3e248e555c71>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mcat_gender\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'cat_gender' is not defined"
     ],
     "output_type": "error"
    }
   ],
   "source": [
    "# Show the categorical column that indicator_column wraps in the next cell.\n",
    "cat_gender"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'tf' is not defined",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-3-b116fcb23207>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mind_gender\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_column\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindicator_column\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcat_gender\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_column\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minput_layer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mind_gender\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'tf' is not defined"
     ],
     "output_type": "error"
    }
   ],
   "source": [
    "# One-hot encode the vocabulary-based 'gender' column.\n",
    "ind_gender = tf.feature_column.indicator_column(cat_gender)\n",
    "inputs = tf.feature_column.input_layer(features, ind_gender)\n",
    "# This input_layer call builds its own vocabulary lookup table, which has to\n",
    "# be initialized before the tensor can be evaluated.\n",
    "sess.run(tf.tables_initializer())\n",
    "sess.run(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Categorical_column 输入"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# feature_column to pre-defined tf.estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
